Conversation

@OutisLi OutisLi commented Jan 12, 2026

Summary by CodeRabbit

  • New Features
    • Added HybridMuon optimizer option with configurable hyperparameters (momentum alias, lr_adjust, lr_adjust_coeff, min_2d_dim, muon_2d_only, weight decay) and integrated training support (initialization, LR tracking, scheduler compatibility, checkpoint resume).
  • Tests
    • Added comprehensive tests covering orthogonalization, optimizer step behavior, weight-decay effects, fallback cases, LR-adjust modes, and state save/load.

Copilot AI review requested due to automatic review settings January 12, 2026 05:11
@dosubot dosubot bot added the new feature label Jan 12, 2026

coderabbitai bot commented Jan 12, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Walkthrough

Adds a new HybridMuonOptimizer (with Newton–Schulz zeropower), exposes it in the public API, integrates it into training/config validation, and introduces unit tests for routing, updates, dtype handling, and state persistence.

Changes

  • Optimizer Implementation (deepmd/pt/optimizer/hybrid_muon.py): New HybridMuonOptimizer class, Newton–Schulz zeropower helpers, routing logic (muon vs adam paths), per-parameter state, bf16/float handling, compile wrapper and constants.
  • Public API (deepmd/pt/optimizer/__init__.py): Import and export HybridMuonOptimizer via __all__.
  • Training Integration (deepmd/pt/train/training.py): Import HybridMuonOptimizer; extend get_opt_param() with KF/LKF params; initialize the "HybridMuon" optimizer; include it in LR tracking, the step flow, and checkpoint state load/save.
  • Argument Schema (deepmd/utils/argcheck.py): Add a HybridMuon variant with hyperparameters (momentum/muon_momentum, adam_beta1, adam_beta2, weight_decay, lr_adjust, lr_adjust_coeff, muon_2d_only, min_2d_dim); add the alias muon_momentum.
  • Tests (source/tests/pt/test_hybrid_muon.py): New tests for Newton–Schulz orthogonalization, BF16 support gating, optimizer routing/updates, weight decay, lr_adjust modes, and state_dict persistence.

Sequence Diagram(s)

sequenceDiagram
    participant Trainer
    participant HybridMuonOptimizer
    participant Router
    participant AdamPath
    participant MuonPath

    Trainer->>HybridMuonOptimizer: step(grads)
    HybridMuonOptimizer->>Router: classify params (first step)
    Router-->>HybridMuonOptimizer: partitions (adam_1d, adam_matrix, muon_params)
    HybridMuonOptimizer->>AdamPath: update 1D params (exp_avg, exp_avg_sq)
    HybridMuonOptimizer->>AdamPath: update small 2D matrices (Adam fallback)
    HybridMuonOptimizer->>MuonPath: update large 2D matrices (momentum, NS orthogonalize)
    MuonPath->>MuonPath: apply lr_adjust, weight decay, per-bucket scaling
    HybridMuonOptimizer-->>Trainer: updated params / state

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • wanghan-iapcm
  • iProzd

🚥 Pre-merge checks: 2 passed, 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): docstring coverage is 72.34%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): the title "feat(pt): add HybridMuonOptimizer" clearly and concisely summarizes the main change: adding a new HybridMuonOptimizer to the PyTorch backend, which is the primary focus across all modified files.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @deepmd/pt/optimizer/muon.py:
- Around line 466-544: raw_deltas are float32 but adam_matrix_params may be
bf16/fp16 causing dtype-mismatch on torch._foreach_add_; before calling
torch._foreach_add_ convert each delta to its corresponding param's dtype/device
(e.g. delta = delta.to(param.dtype).to(param.device) or delta =
delta.type_as(param))—do this after clipping/scaling and then call
torch._foreach_add_(adam_matrix_params, casted_deltas) so updates match
parameter dtypes.
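
A minimal sketch of that cast (names follow the comment above; this is not the file's actual code):

import torch

def apply_adam_matrix_deltas(adam_matrix_params: list, raw_deltas: list) -> None:
    # Cast each float32 delta to its parameter's dtype/device after
    # clipping/scaling, so the in-place foreach add also works for
    # bf16/fp16 parameters.
    casted_deltas = [
        d.to(device=p.device, dtype=p.dtype)
        for p, d in zip(adam_matrix_params, raw_deltas, strict=True)
    ]
    torch._foreach_add_(adam_matrix_params, casted_deltas)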

In @source/tests/pt/test_muon.py:
- Around line 21-56: The test should skip on devices that don't support bfloat16
matmul: in test_orthogonalization (and optionally test_shape_and_dtype) probe
bf16 matmul support by attempting a small BF16 matmul on self.device inside a
try/except (e.g., create two tiny tensors with dtype=torch.bfloat16 and call
.matmul or torch.matmul) and call self.skipTest with an explanatory message if
it raises or is unsupported; use self.skipTest rather than asserting so CI
quietly skips environments where zeropower_via_newtonschulz5's BF16 path cannot
run reliably.
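
A hedged sketch of the suggested probe as it might sit on the test class (the self.device attribute is assumed, as in the comment):

import unittest

import torch

class TestOrthogonalization(unittest.TestCase):
    device = "cpu"  # assumption: the real test selects its device elsewhere

    def _skip_if_no_bf16_matmul(self) -> None:
        # Probe BF16 matmul support; skip quietly so CI passes on
        # environments where the Newton-Schulz BF16 path cannot run.
        try:
            a = torch.randn(4, 4, dtype=torch.bfloat16, device=self.device)
            torch.matmul(a, a)
        except (RuntimeError, TypeError):
            self.skipTest(f"bfloat16 matmul not supported on {self.device}")
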
🧹 Nitpick comments (3)
deepmd/utils/argcheck.py (1)

3399-3526: Config surface matches the runtime wiring; minor doc/UX nits.

  • Good: Muon exposes min_2d_dim, lr_adjust, lr_adjust_coeff, and Adam betas consistent with deepmd/pt/train/training.py.
  • Consider clarifying in docs that muon_momentum is an alias for both AdaMuon/Muon within their respective opt_type blocks (to reduce confusion).
source/tests/pt/test_muon.py (1)

80-82: Prefer zip(..., strict=True) in tests. Avoids silently ignoring length mismatches.

Proposed tweak
-        for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params)):
+        for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params, strict=True)):
             self.assertFalse(torch.allclose(p, init_p), f"Parameter {i} did not change")
deepmd/pt/optimizer/muon.py (1)

77-89: Consider wrapping torch.compile() in try/except for robustness on unsupported devices or graph patterns. While the Newton-Schulz functions use standard operations that should compile reliably, adding a fallback allows graceful degradation if compilation fails on certain hardware configurations or edge cases:

try:
    return torch.compile(fn, fullgraph=True, dynamic=True)
except Exception:
    return fn

This is particularly useful since fullgraph=True can fail hard rather than gracefully degrade. Given PyTorch 2.7+ is required (where torch.compile is stable), the redundant hasattr(torch, "compile") check can be simplified or removed.

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 82a5f32 and c389ffc.

📒 Files selected for processing (5)
  • deepmd/pt/optimizer/__init__.py
  • deepmd/pt/optimizer/muon.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_muon.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:11.377Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.
📚 Learning: 2026-01-10T04:28:11.377Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:11.377Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.

Applied to files:

  • deepmd/pt/optimizer/muon.py
  • deepmd/pt/train/training.py
  • source/tests/pt/test_muon.py
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.

Applied to files:

  • deepmd/pt/train/training.py
🧬 Code graph analysis (3)
deepmd/pt/optimizer/muon.py (2)
deepmd/pt/train/training.py (1)
  • step (799-1171)
source/tests/pt/test_adamuon.py (1)
  • closure (328-334)
deepmd/pt/train/training.py (1)
deepmd/pt/optimizer/muon.py (1)
  • MuonOptimizer (226-646)
source/tests/pt/test_muon.py (1)
deepmd/pt/optimizer/muon.py (3)
  • MuonOptimizer (226-646)
  • zeropower_via_newtonschulz5 (159-186)
  • step (379-646)
🪛 Ruff (0.14.10)
deepmd/pt/optimizer/muon.py

186-186: Avoid specifying long messages outside the exception class

(TRY003)


216-216: Avoid specifying long messages outside the exception class

(TRY003)


308-308: Avoid specifying long messages outside the exception class

(TRY003)

source/tests/pt/test_muon.py

80-80: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

🔇 Additional comments (6)
deepmd/pt/optimizer/__init__.py (1)

11-15: Public export wiring looks correct. MuonOptimizer is imported and added to __all__, enabling from deepmd.pt.optimizer import MuonOptimizer.

deepmd/pt/train/training.py (4)

44-49: Import + optimizer exposure is consistent with new public API.


159-177: Nice hardening: provide LKF defaults + Muon knobs in get_opt_param. Prevents missing-key crashes when LKF-specific keys aren’t present.


720-754: Muon optimizer integration looks correct; ensure state dict restore is covered. min_2d_dim is passed only for Muon (as expected), and scheduler + resume path are consistent with Adam/AdaMuon branches.


823-829: LR display/scheduler path correctly includes "Muon".

deepmd/pt/optimizer/muon.py (1)

65-75: Keep NS coefficients as-is (they match standard references). NS_COEFF_A/B/C match the canonical (3.4445, -4.7750, 2.0315). Based on learnings.

Copilot AI left a comment

Pull request overview

This pull request adds a new Muon optimizer to the DeePMD-kit PyTorch backend. Muon accumulates gradient momentum and then applies Newton-Schulz orthogonalization to the momentum-based update, yielding orthogonalized updates for weight matrices. The optimizer uses different update strategies based on parameter dimensionality: Muon with Newton-Schulz for >=2D parameters, and Adam for 1D parameters (biases, norms).
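
A minimal sketch of that routing rule (a hypothetical helper, not the PR's code; the real optimizer also honors min_2d_dim and muon_2d_only as discussed later in this thread):

import torch

def route_param(p: torch.nn.Parameter, min_2d_dim: int = 1, muon_2d_only: bool = True) -> str:
    # 2D matrices that are large enough take the Muon path (momentum plus
    # Newton-Schulz orthogonalization); 1D params and small matrices use Adam.
    if p.ndim == 2 and min(p.shape) >= min_2d_dim:
        return "muon"
    if p.ndim > 2 and not muon_2d_only:
        return "muon"
    return "adam"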

Changes:

  • Implemented MuonOptimizer with Newton-Schulz orthogonalization algorithm
  • Added configuration support for Muon optimizer parameters in argcheck
  • Integrated Muon optimizer into the training pipeline with scheduler support

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

Show a summary per file
  • deepmd/pt/optimizer/muon.py: New file implementing MuonOptimizer with Newton-Schulz orthogonalization for >=2D params and Adam for 1D params
  • deepmd/pt/optimizer/__init__.py: Added MuonOptimizer to module exports
  • deepmd/utils/argcheck.py: Added Muon optimizer configuration with parameters for momentum, Adam betas, weight decay, lr_adjust, and min_2d_dim
  • deepmd/pt/train/training.py: Integrated Muon optimizer initialization, parameter extraction, and scheduler setup; minor string formatting improvements
  • source/tests/pt/test_muon.py: Comprehensive test suite covering Newton-Schulz orthogonalization, optimizer step behavior, parameter routing, weight decay, lr_adjust modes, and state dict handling

codecov bot commented Jan 12, 2026

Codecov Report

❌ Patch coverage is 79.18089% with 61 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.93%. Comparing base (567c5ba) to head (6109e6e).
⚠️ Report is 1 commit behind head on master.

Files with missing lines (patch coverage and missing lines):
  • deepmd/pt/optimizer/hybrid_muon.py: 79.51% patch coverage, 59 lines missing ⚠️
  • deepmd/pt/train/training.py: 33.33% patch coverage, 2 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5149      +/-   ##
==========================================
- Coverage   81.95%   81.93%   -0.02%     
==========================================
  Files         713      714       +1     
  Lines       72985    73277     +292     
  Branches     3617     3617              
==========================================
+ Hits        59812    60043     +231     
- Misses      12010    12072      +62     
+ Partials     1163     1162       -1     

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
deepmd/pt/train/training.py (2)

720-753: AdaMuon does not initialize self.scheduler but step() requires it.

The AdaMuon branch (lines 720-732) creates the optimizer without initializing self.scheduler, but the step() method at lines 824 and 841 assumes self.scheduler exists for both AdaMuon and HybridMuon. This will cause an AttributeError at runtime when using AdaMuon.

Add scheduler initialization to the AdaMuon branch:

if optimizer_state_dict is not None and self.restart_training:
    self.optimizer.load_state_dict(optimizer_state_dict)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
    self.optimizer,
    lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
)

159-177: Config-schema mismatch: LKF preference parameters are code-accessible but not schema-exposed.

get_opt_param() reads kf_start_pref_e, kf_limit_pref_e, kf_start_pref_f, and kf_limit_pref_f (lines 162–165) and these are actively used to calculate preference weights during training (in the step method for EnergyStdLoss). However, the LKF schema in deepmd/utils/argcheck.py only defines kf_blocksize. Users cannot configure these preference parameters via the config file because schema validation will not recognize them.
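
If these preferences should become configurable, the LKF schema additions could look roughly like the sketch below (placeholder defaults and a hypothetical variable name, not the repository's actual values):

from dargs import Argument

lkf_pref_args = [
    Argument("kf_start_pref_e", float, optional=True, default=1.0,
             doc="Starting energy prefactor for the Kalman-filter loss."),
    Argument("kf_limit_pref_e", float, optional=True, default=1.0,
             doc="Limit energy prefactor for the Kalman-filter loss."),
    Argument("kf_start_pref_f", float, optional=True, default=1.0,
             doc="Starting force prefactor for the Kalman-filter loss."),
    Argument("kf_limit_pref_f", float, optional=True, default=1.0,
             doc="Limit force prefactor for the Kalman-filter loss."),
]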

🤖 Fix all issues with AI agents
In @source/tests/pt/test_hybrid_muon.py:
- Around line 107-109: The test loop uses zip(model.parameters(),
initial_params) which can silently truncate if lengths differ; update the loop
in test_hybrid_muon.py to use zip(model.parameters(), initial_params,
strict=True) so mismatched lengths raise an error (requires Python 3.10+),
keeping the same enumerate and assertion (i, p, init_p identifiers unchanged).
🧹 Nitpick comments (5)
deepmd/pt/optimizer/hybrid_muon.py (4)

88-100: Fix type annotations: callable is not a type.
Using the builtin callable in annotations is non-idiomatic and breaks type checking; use collections.abc.Callable (or typing.Callable) instead.

Proposed diff
@@
-from typing import (
+from typing import (
     TYPE_CHECKING,
     Any,
 )
@@
-if TYPE_CHECKING:
+if TYPE_CHECKING:
     from collections.abc import (
+        Callable,
         Iterable,
     )
@@
 def _maybe_compile(
-    fn: callable,
-) -> callable:
+    fn: "Callable[..., Any]",
+) -> "Callable[..., Any]":
@@
 def step(
     self,
-    closure: callable | None = None,
+    closure: "Callable[[], torch.Tensor]" | None = None,
 ) -> torch.Tensor | None:

Also applies to: 389-393


76-86: Consider aligning NS_EPS with AdaMuon’s Newton–Schulz epsilon for consistency.
Coefficients match the standard (good), but NS_EPS=1e-7 differs from the AdaMuon implementation used elsewhere in this repo (and the learned “don’t change eps” guidance). If the change is intentional, a short comment explaining why HybridMuon needs a different epsilon would help. Based on learnings, keep constants consistent unless there’s a measured reason.

Also applies to: 120-128, 154-162


335-388: Static routing: please document that routing won’t change after the first step().
This is probably fine, but it’s worth calling out explicitly because parameter freezing/unfreezing or adding param groups mid-training won’t be reflected after _routing_built=True.


477-555: Small-2D Adam fallback: avoid hard-coded magic caps (or expose them).
max_rel_change=0.05, abs_floor=1e-3*sqrt(numel), and min(lr_adjust_coeff, 0.1) are important stability knobs but currently “hidden”. Consider making them constants (at least) or args so behavior is tunable and testable.
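
One way to surface these knobs, sketched with hypothetical constant names (the values are the ones quoted above):

# Hypothetical module-level constants replacing the hard-coded caps in the
# small-2D Adam fallback, so they are visible, tunable, and testable.
ADAM_FALLBACK_MAX_REL_CHANGE = 0.05   # cap on the per-step relative change
ADAM_FALLBACK_ABS_FLOOR_COEFF = 1e-3  # floor = coeff * sqrt(param.numel())
ADAM_FALLBACK_LR_SCALE_CAP = 0.1      # scale = min(lr_adjust_coeff, cap)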

source/tests/pt/test_hybrid_muon.py (1)

15-38: BF16 gating is reasonable, but import-time probing can be a little heavy.
Not a blocker, but if this ever becomes flaky on CI, consider moving the probe into setUpClass to avoid side effects on import.

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c389ffc and 1978c7f.

📒 Files selected for processing (5)
  • deepmd/pt/optimizer/__init__.py
  • deepmd/pt/optimizer/hybrid_muon.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_hybrid_muon.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2026-01-10T04:28:18.703Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:18.703Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.

Applied to files:

  • deepmd/pt/optimizer/hybrid_muon.py
  • source/tests/pt/test_hybrid_muon.py
  • deepmd/pt/train/training.py
📚 Learning: 2026-01-10T04:29:25.299Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:44-113
Timestamp: 2026-01-10T04:29:25.299Z
Learning: The epsilon parameter (eps=1e-8) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` follows the standard PyTorch AdaMuon reference implementations (which typically use eps=1e-7) and should not be modified for consistency with reference implementations, despite being smaller than bfloat16's machine epsilon.

Applied to files:

  • deepmd/pt/optimizer/hybrid_muon.py
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.

Applied to files:

  • deepmd/pt/train/training.py
🧬 Code graph analysis (4)
deepmd/pt/optimizer/hybrid_muon.py (1)
source/tests/pt/test_adamuon.py (1)
  • closure (328-334)
source/tests/pt/test_hybrid_muon.py (2)
deepmd/pt/optimizer/hybrid_muon.py (1)
  • zeropower_via_newtonschulz5 (170-197)
source/install/build_tf.py (1)
  • env (829-839)
deepmd/pt/optimizer/__init__.py (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
  • HybridMuonOptimizer (237-657)
deepmd/pt/train/training.py (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
  • HybridMuonOptimizer (237-657)
🪛 Ruff (0.14.10)
deepmd/pt/optimizer/hybrid_muon.py

197-197: Avoid specifying long messages outside the exception class

(TRY003)


227-227: Avoid specifying long messages outside the exception class

(TRY003)


319-319: Avoid specifying long messages outside the exception class

(TRY003)

source/tests/pt/test_hybrid_muon.py

32-32: Consider moving this statement to an else block

(TRY300)


107-107: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (6)
deepmd/utils/argcheck.py (2)

3399-3407: AdaMuon momentum alias looks good for backward compatibility.
Adding alias=["muon_momentum"] should help config migrations without breaking existing momentum usage.


3454-3532: Documentation claim about torch.optim.Muon is accurate—no changes needed.

PyTorch 2.7+ (the minimum supported version for GPU, and 2.8+ for CPU per pyproject.toml) includes torch.optim.Muon, which is indeed designed to optimize only 2D parameters via Newton-Schulz orthogonalization. The doc statement correctly distinguishes DeePMD's HybridMuon (supporting both 1D and 2D parameters) from PyTorch's Muon (2D only), and is not misleading.

deepmd/pt/optimizer/__init__.py (1)

1-20: Public export wiring looks correct.
Import + __all__ update is consistent with training.py usage.

deepmd/pt/train/training.py (2)

733-746: HybridMuon integration looks consistent with the new optimizer interface.
Passing lr_adjust/lr_adjust_coeff/min_2d_dim matches HybridMuonOptimizer.__init__, and state_dict loading mirrors Adam/AdaMuon flows.


639-641: Logging format changes look fine.
The single-line warning/info logs are clearer and keep log parsing simple.

Also applies to: 1564-1566

source/tests/pt/test_hybrid_muon.py (1)

40-228: Test coverage looks solid for routing + state_dict round-trip.
The separation checks (momentum_buffer vs exp_avg/exp_avg_sq) and min_2d_dim fallback test should catch most integration regressions.

@OutisLi OutisLi requested a review from njzjz January 13, 2026 03:28

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @deepmd/pt/optimizer/hybrid_muon.py:
- Around line 684-724: The muon update computes an `orth` delta (via
_zeropower_via_newtonschulz5_2d and _zeropower_via_newtonschulz5_3d) in bf16 but
applies it in-place to params of potentially different dtypes causing runtime
errors; cast the delta(s) to the target parameter dtype before calling in-place
add_. Specifically, in the single-matrix branch cast `orth`/`delta` to
`entry["param"].dtype` before `entry["param"].add_`, and in the batched branch
cast each `orth[i]` (or the slice used) to `params[i].dtype` (or call
.to(params[i].dtype)) before `params[i].add_`.
- Around line 77-99: The _maybe_compile function currently calls
torch.compile(fn, fullgraph=True, dynamic=True) at import time and can raise
exceptions that break imports; wrap the torch.compile call in a try/except that
catches Exception (or RuntimeError) and returns the original fn on any
compilation failure, while preserving the existing default_device check and
behavior; ensure the fallback logs or silently ignores the compile error and
returns fn so that _maybe_compile and functions wrapped by it (e.g., any callers
of _maybe_compile) remain usable when torch.compile is unavailable or fails.
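
A minimal sketch of that fallback, assuming a module-level logger (the real wrapper may also keep its existing default-device check):

import logging

import torch

log = logging.getLogger(__name__)

def _maybe_compile(fn):
    # Compile when possible, but never let a compilation failure at import
    # time break the module: fall back to eager execution instead.
    if not hasattr(torch, "compile"):
        return fn
    try:
        return torch.compile(fn, fullgraph=True, dynamic=True)
    except Exception:
        log.warning("torch.compile failed; falling back to eager execution", exc_info=True)
        return fn
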
🧹 Nitpick comments (2)
deepmd/utils/argcheck.py (1)

3454-3542: HybridMuon arg schema matches the intended routing knobs; one doc nit: “Nesterov” formula wording is slightly off.
You describe m_t = beta*m_{t-1} + (1-beta)*g_t (momentum EMA) and later apply a Nesterov-style lookahead; consider rewording to avoid implying the EMA itself is “Nesterov”.
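
For reference, one common way to separate the two steps (a sketch; the PR's exact update may differ):

m_t = \beta \, m_{t-1} + (1 - \beta) \, g_t            (plain momentum EMA)
\tilde{g}_t = \beta \, m_t + (1 - \beta) \, g_t        (Nesterov-style lookahead, which is then orthogonalized)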

deepmd/pt/train/training.py (1)

159-178: get_opt_param() now always includes LKF + Muon knobs; please sanity-check you don’t silently accept misspelled keys.
Not a blocker, but when configs evolve, it’s easy to carry dead/typoed fields unnoticed.
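
A hypothetical guard (the key set and names below are illustrative, not the file's contents) that would make typoed options visible:

import logging

log = logging.getLogger(__name__)

_KNOWN_OPT_KEYS = {
    "momentum", "adam_beta1", "adam_beta2", "weight_decay",
    "lr_adjust", "lr_adjust_coeff", "muon_2d_only", "min_2d_dim",
    "kf_blocksize", "kf_start_pref_e", "kf_limit_pref_e",
    "kf_start_pref_f", "kf_limit_pref_f",
}

def warn_unknown_opt_keys(params: dict) -> None:
    # Misspelled or stale optimizer options would otherwise be dropped silently.
    unknown = set(params) - _KNOWN_OPT_KEYS
    if unknown:
        log.warning("Ignoring unrecognized optimizer options: %s", sorted(unknown))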

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1978c7f and 74cf72d.

📒 Files selected for processing (3)
  • deepmd/pt/optimizer/hybrid_muon.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2026-01-10T04:28:18.703Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:18.703Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.

Applied to files:

  • deepmd/pt/train/training.py
  • deepmd/pt/optimizer/hybrid_muon.py
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.

Applied to files:

  • deepmd/pt/train/training.py
📚 Learning: 2026-01-10T04:29:25.299Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:44-113
Timestamp: 2026-01-10T04:29:25.299Z
Learning: The epsilon parameter (eps=1e-8) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` follows the standard PyTorch AdaMuon reference implementations (which typically use eps=1e-7) and should not be modified for consistency with reference implementations, despite being smaller than bfloat16's machine epsilon.

Applied to files:

  • deepmd/pt/optimizer/hybrid_muon.py
🧬 Code graph analysis (2)
deepmd/pt/train/training.py (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
  • HybridMuonOptimizer (236-725)
deepmd/pt/optimizer/hybrid_muon.py (2)
deepmd/pt/train/training.py (1)
  • step (801-1173)
source/tests/pt/test_adamuon.py (1)
  • closure (328-334)
🪛 Ruff (0.14.10)
deepmd/pt/optimizer/hybrid_muon.py

196-196: Avoid specifying long messages outside the exception class

(TRY003)


226-226: Avoid specifying long messages outside the exception class

(TRY003)


324-324: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
🔇 Additional comments (7)
deepmd/utils/argcheck.py (1)

3396-3450: AdaMuon: muon_momentum alias looks good; please ensure it’s documented consistently in user docs/examples.
This improves backward/forward compatibility for config naming (esp. when users think in “Muon momentum” rather than generic “momentum”).

deepmd/pt/optimizer/hybrid_muon.py (3)

101-197: Newton–Schulz kernels: coeffs look consistent with reference implementations (good).
One thing to verify: using a single EPS = 1e-7 for both NS norm clamp and Adam epsilon may be intentional, but it’s a behavior change vs AdaMuon-style eps choices. Based on learnings, keep coeffs unchanged; please double-check eps expectations.


341-407: Static routing is sensible, but it assumes parameter ranks/shapes stay stable across training.
That’s probably fine for this codebase; just be aware it will misroute if someone swaps modules/params mid-training (rare).


443-627: The implementation looks sound. The bias-correction formula matches PyTorch's Adam (state["beta1_pow"] and state["beta2_pow"] are cumulative products β^t, giving bias_corr = 1 − β^t), and there's existing test coverage for state_dict roundtrip in TestHybridMuonOptimizerStateDict.test_state_dict_save_load. The use of torch._foreach_* is acceptable for the supported torch version range (≥2.7), and this pattern is established elsewhere in the codebase (e.g., adamuon.py). No action needed.
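
For reference, the textbook Adam bias correction this matches (standard symbols, not the file's variable names):

\hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t), \qquad
\theta_t = \theta_{t-1} - \eta \, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)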

deepmd/pt/train/training.py (3)

44-49: Training import + optimizer registry wiring for HybridMuonOptimizer looks consistent.


825-844: Step-loop inclusion (HybridMuon alongside Adam/AdaMuon) is correct for scheduler LR tracking.


734-755: HybridMuon init matches argcheck + optimizer signature; muon_2d_only and min_2d_dim are documented in class docstring and argcheck with test examples provided. Remove the unrelated bf16 dtype query from this review.

@OutisLi OutisLi changed the title from "feat(pt): add Muon optimizer" to "feat(pt): add HybridMuonOptimizer" Jan 13, 2026
fix(pt): Muon bug fix

feat&fix(pt): Muon add bf16 support

feat(pt): use tf32 for Muon

fix(pt): Use 1e-8 for Muon

feat(pt): Update Muon

fix(pt): use the same lr for adam inside Muon

feat(pt): add match_rms for Muon

feat(pt): adjust Muon

feat(pt): Update Muon

(cherry picked from commit 9b4e63d)
(cherry picked from commit 46fcb7d)
(cherry picked from commit 1dd737f)
Changes:
1. Remove dtype conversion: NS output (bfloat16) now directly applied to
   parameters, matching torch.optim.Muon behavior where PyTorch handles
   mixed precision automatically.

2. Add muon_2d_only parameter (default True): When True, only 2D parameters
   use Muon; >2D parameters use Adam without weight decay. This matches
   PyTorch's official torch.optim.Muon which only supports 2D matrices.

3. Merge NS_EPS and ADAM_EPS into single EPS constant (both 1e-7).

4. Update dtype documentation to reflect actual behavior:
   - NS output (bfloat16) directly applied to parameters
   - Muon momentum buffer follows gradient dtype (not param dtype)

5. Update weight_decay docstring from ">=2D params" to "Muon-routed
   parameters" for accuracy with muon_2d_only=True.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In @deepmd/pt/optimizer/hybrid_muon.py:
- Around line 87-99: The _maybe_compile function currently calls torch.compile
directly which can raise exceptions; update it to call torch.compile(fn,
fullgraph=True, dynamic=True) inside a try/except that catches Exception (or
RuntimeError) and on failure logs or silently ignores the error and returns the
original fn so compilation failures fall back to eager execution; reference the
_maybe_compile function and torch.compile call and ensure the except path
returns fn (and optionally logs the exception via the module logger or
warnings).
- Around line 443-724: The foreach calls fail on heterogeneous devices/dtypes;
bucket tensors by (device,dtype) before calling torch._foreach_* and run the
foreach per-bucket (or fallback to per-parameter ops) for muon_params_for_decay
(torch._foreach_mul_), for muon_grads and muon_momentum_buffers
(torch._foreach_lerp_ and torch._foreach_lerp) and for
adam_matrix_params/raw_deltas when calling torch._foreach_norm and torch.stack;
specifically, group lists into buckets keyed by (tensor.device, tensor.dtype)
(like the existing buckets for Newton-Schulz), then call the corresponding
torch._foreach_* on each bucket's sublists and replace the original single-call
sites (references: muon_params_for_decay, muon_grads, muon_momentum_buffers,
adam_matrix_params, raw_deltas, and the torch._foreach_* invocations) so every
foreach sees homogeneous device/dtype inputs and you avoid mixed-precision
crashes.
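
A minimal sketch of the per-bucket foreach pattern described above, shown for the weight-decay multiply (the same grouping applies to the lerp and norm calls):

from collections import defaultdict

import torch

def foreach_mul_by_bucket(tensors: list, scalar: float) -> None:
    # Group tensors by (device, dtype) so every torch._foreach_* call sees
    # homogeneous inputs; mixed groups would otherwise raise at runtime.
    buckets = defaultdict(list)
    for t in tensors:
        buckets[(t.device, t.dtype)].append(t)
    for group in buckets.values():
        torch._foreach_mul_(group, scalar)
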
🧹 Nitpick comments (5)
deepmd/pt/optimizer/hybrid_muon.py (1)

77-85: Align/justify EPS across AdaMuon vs HybridMuon (currently 1e-7 here).
You kept the standard NS coefficients (good). EPS differs from deepmd/pt/optimizer/adamuon.py (which historically used 1e-8 per repo learnings). If the difference is intentional (e.g., matching official Muon), please add a short comment explaining why HybridMuon diverges so users don’t “fix” it later. Based on learnings, the coefficients should stay unchanged.

Also applies to: 119-121

deepmd/pt/train/training.py (1)

159-178: Make get_opt_param() resilient to alias keys (if configs bypass normalization).
Argcheck introduces aliases like muon_momentum / muon_min_2d_dim, but get_opt_param() reads only momentum / min_2d_dim. If a caller provides raw configs without running the normalizer, HybridMuon/AdaMuon may silently ignore the alias values.

Proposed fix
-                "momentum": params.get("momentum", 0.95),
+                "momentum": params.get("momentum", params.get("muon_momentum", 0.95)),
@@
-                "min_2d_dim": params.get("min_2d_dim", 1),
+                "min_2d_dim": params.get("min_2d_dim", params.get("muon_min_2d_dim", 1)),
deepmd/utils/argcheck.py (1)

3420-3476: Docs: mention lr_adjust_coeff also affects small-2D Adam fallback (current implementation).
In HybridMuonOptimizer, lr_adjust_coeff is “dual-purpose” (match-RMS scaling and matrix-fallback LR scaling via min(lr_adjust_coeff, 0.1)). The schema doc currently describes only match-RMS scaling. Consider adding one sentence to prevent surprise when users tune it.

Also applies to: 3478-3566

source/tests/pt/test_hybrid_muon.py (2)

29-34: Consider moving the success return to an else block.

The linter suggests placing the success path in an else block for clearer control flow separation between success and error handling.

♻️ Suggested refactor
     try:
         a = torch.randn(4, 4, dtype=torch.bfloat16, device=device)
         _ = torch.mm(a, a.T)
-        return True
     except (RuntimeError, TypeError):
         return False
+    else:
+        return True

107-108: Add strict=True to zip() for defensive checking.

While both iterables originate from the same model making length mismatch unlikely, adding strict=True provides early failure detection if the iteration logic ever changes.

♻️ Suggested fix
-        for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params)):
+        for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params, strict=True)):
             self.assertFalse(torch.allclose(p, init_p), f"Parameter {i} did not change")
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 74cf72d and d3a5abf.

📒 Files selected for processing (5)
  • deepmd/pt/optimizer/__init__.py
  • deepmd/pt/optimizer/hybrid_muon.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_hybrid_muon.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2026-01-10T04:28:18.703Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:18.703Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.

Applied to files:

  • deepmd/pt/optimizer/hybrid_muon.py
  • source/tests/pt/test_hybrid_muon.py
  • deepmd/pt/train/training.py
📚 Learning: 2026-01-10T04:29:25.299Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:44-113
Timestamp: 2026-01-10T04:29:25.299Z
Learning: The epsilon parameter (eps=1e-8) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` follows the standard PyTorch AdaMuon reference implementations (which typically use eps=1e-7) and should not be modified for consistency with reference implementations, despite being smaller than bfloat16's machine epsilon.

Applied to files:

  • deepmd/pt/optimizer/hybrid_muon.py
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.

Applied to files:

  • deepmd/pt/train/training.py
🧬 Code graph analysis (4)
deepmd/pt/optimizer/__init__.py (3)
deepmd/pt/optimizer/hybrid_muon.py (1)
  • HybridMuonOptimizer (236-725)
deepmd/pt/optimizer/KFWrapper.py (1)
  • KFOptimizerWrapper (13-140)
deepmd/pt/optimizer/LKF.py (1)
  • LKFOptimizer (30-325)
deepmd/pt/optimizer/hybrid_muon.py (2)
deepmd/pt/train/training.py (1)
  • step (819-1191)
source/tests/pt/test_adamuon.py (1)
  • closure (330-336)
source/tests/pt/test_hybrid_muon.py (1)
deepmd/pt/optimizer/hybrid_muon.py (3)
  • HybridMuonOptimizer (236-725)
  • zeropower_via_newtonschulz5 (169-196)
  • step (409-725)
deepmd/pt/train/training.py (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
  • HybridMuonOptimizer (236-725)
🪛 Ruff (0.14.10)
deepmd/pt/optimizer/hybrid_muon.py

196-196: Avoid specifying long messages outside the exception class

(TRY003)


226-226: Avoid specifying long messages outside the exception class

(TRY003)


324-324: Avoid specifying long messages outside the exception class

(TRY003)

source/tests/pt/test_hybrid_muon.py

32-32: Consider moving this statement to an else block

(TRY300)


107-107: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (20)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
🔇 Additional comments (8)
deepmd/pt/optimizer/__init__.py (1)

5-20: LGTM: clean public API export for HybridMuonOptimizer.

deepmd/pt/train/training.py (1)

752-773: HybridMuon integration looks consistent with Adam/AdaMuon scheduling and stepping.
Instantiation wiring (betas, lr_adjust*, muon_2d_only, min_2d_dim) + inclusion in the Adam-family step path is coherent.

Also applies to: 843-861

deepmd/utils/argcheck.py (1)

3420-3431: Alias normalization is already handled correctly by the dargs framework.

The config is explicitly normalized via deepmd.utils.argcheck.normalize() in the entrypoint (deepmd/pt/entrypoints/main.py) before Trainer instantiation. This function uses dargs' normalize_value() method, which automatically handles alias-to-canonical key mapping for nested Argument and Variant schemas recursively. The optimizer code in training.py accesses parameters using only canonical names ("momentum", "min_2d_dim"), not the alias names. All entrypoint code paths follow the normalize → Trainer flow, so un-normalized config reaching the optimizer is not a concern.

source/tests/pt/test_hybrid_muon.py (5)

1-13: LGTM!

License header and imports are appropriate for the test module.


40-82: LGTM!

The Newton-Schulz orthogonalization tests provide good coverage:

  • Orthogonality verification with appropriate tolerances for bf16 precision.
  • Shape and dtype preservation for 2D and 3D inputs.
  • Error handling for invalid 1D input.

110-191: LGTM!

The optimizer tests provide solid coverage of key behaviors:

  • Weight decay effect verification.
  • Correct routing of 2D params to Muon (momentum_buffer) and 1D to Adam (exp_avg/exp_avg_sq).
  • Adam fallback for small 2D matrices based on min_2d_dim.
  • Different lr_adjust modes producing distinct updates.

Fixed seeds ensure reproducibility.


194-227: LGTM!

State dict save/load test correctly verifies:

  • State is properly serialized and deserialized.
  • Both tensor values (via torch.allclose) and scalar values are compared.
  • Per-parameter state integrity is maintained.

230-231: LGTM!

Standard test entry point.

@OutisLi OutisLi requested a review from njzjz January 14, 2026 05:15

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/train/training.py (1)

739-772: Add scheduler setup for AdaMuon optimizer.

The AdaMuon block (lines 739-751) lacks scheduler initialization, yet line 843 includes "AdaMuon" in the condition that calls self.scheduler.get_last_lr()[0] at line 844. This will cause an AttributeError at runtime.

Add the missing scheduler initialization following the HybridMuon pattern (lines 768-772):

if optimizer_state_dict is not None and self.restart_training:
    self.optimizer.load_state_dict(optimizer_state_dict)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
    self.optimizer,
    lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
)

Insert this block immediately after the AdaMuon optimizer initialization (after line 751).

🧹 Nitpick comments (2)
source/tests/pt/test_hybrid_muon.py (2)

15-34: LGTM! Robust BF16 support detection.

The function handles CUDA capability checks gracefully and falls back to a practical matmul test for CPU. This addresses the CI robustness concern from previous reviews.

Minor style note: the static analyzer suggests moving return True (line 32) to an else block for clarity, but this is optional.

♻️ Optional: Move return to else block per TRY300
     try:
         a = torch.randn(4, 4, dtype=torch.bfloat16, device=device)
         _ = torch.mm(a, a.T)
-        return True
     except (RuntimeError, TypeError):
         return False
+    else:
+        return True

84-193: Consider adding test coverage for muon_2d_only=False.

All tests use the default muon_2d_only=True. Adding a test with muon_2d_only=False would verify that ≥2D (e.g., 3D) parameters use Muon instead of Adam, exercising that routing path.
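
A hedged sketch of such a test, to be added to the existing TestCase (it assumes the constructor keywords used elsewhere in this thread, e.g. lr and muon_2d_only, and the state keys momentum_buffer / exp_avg):

import torch
from deepmd.pt.optimizer import HybridMuonOptimizer

def test_muon_routing_3d_when_not_2d_only(self):
    # With muon_2d_only=False, a 3D parameter should take the Muon path
    # (momentum_buffer) rather than the Adam fallback (exp_avg / exp_avg_sq).
    p3d = torch.nn.Parameter(torch.randn(4, 8, 8))
    opt = HybridMuonOptimizer([p3d], lr=1e-3, muon_2d_only=False)
    p3d.grad = torch.randn_like(p3d)
    opt.step()
    state = opt.state[p3d]
    self.assertIn("momentum_buffer", state)
    self.assertNotIn("exp_avg", state)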

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d3a5abf and 6109e6e.

📒 Files selected for processing (2)
  • deepmd/pt/train/training.py
  • source/tests/pt/test_hybrid_muon.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:18.703Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.
📚 Learning: 2026-01-10T04:28:18.703Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:18.703Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.

Applied to files:

  • source/tests/pt/test_hybrid_muon.py
  • deepmd/pt/train/training.py
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.

Applied to files:

  • deepmd/pt/train/training.py
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
  • HybridMuonOptimizer (236-725)
🪛 Ruff (0.14.11)
source/tests/pt/test_hybrid_muon.py

32-32: Consider moving this statement to an else block

(TRY300)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (cuda120, cuda)
🔇 Additional comments (9)
deepmd/pt/train/training.py (5)

44-49: LGTM!

The HybridMuonOptimizer import is correctly added alongside existing optimizer imports.


159-178: LGTM!

The get_opt_param function correctly extracts parameters for the HybridMuon optimizer, including muon_2d_only and min_2d_dim which are specific to HybridMuon as indicated in the external code snippet.


843-843: LGTM!

Correctly includes HybridMuon in the optimizer types that use the LR scheduler.


656-658: LGTM!

The log message consolidation to single-line format improves readability.


1584-1586: LGTM!

The log message format change is consistent with the other logging updates in this PR.

source/tests/pt/test_hybrid_muon.py (4)

40-82: LGTM!

The Newton-Schulz orthogonalization tests appropriately validate:

  • Approximate orthogonality with tolerances suitable for bf16 precision
  • Shape preservation across 2D and batched 3D inputs
  • Error handling for invalid 1D input

Based on learnings, the underlying NS coefficients (3.4445, -4.7750, 2.0315) are standard and should not be modified.


84-143: LGTM!

Good coverage of optimizer fundamentals:

  • test_step validates parameters change after optimization
  • test_weight_decay confirms decay reduces norms
  • test_muon_adam_separation verifies the routing logic (Muon for 2D weights, Adam for 1D biases)

The strict=True in zip() at line 108 addresses the previous review feedback.


145-193: LGTM!

These tests validate key HybridMuon behaviors:

  • test_muon_adam_fallback_small_2d: Correctly tests the min_2d_dim threshold causing fallback to Adam
  • test_lr_adjust_modes: Verifies that different lr_adjust values produce different update behaviors

196-229: LGTM!

The state dict persistence test properly verifies that optimizer state survives save/load cycles by comparing per-parameter state tensors.
